import importlib.util
import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING

import mlop
from mlop.util import import_lib

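# Import transformers symbols directly only for type checkers; at runtime they
# are resolved through import_lib with harmless fallbacks, so this module still
# imports cleanly when transformers is absent.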
if TYPE_CHECKING:
    from transformers import modelcard
    from transformers.integrations.integration_utils import (
        INTEGRATION_TO_CALLBACK,
        get_reporting_integration_callbacks,
        rewrite_logs,
    )
    from transformers.trainer import Trainer
    from transformers.trainer_callback import TrainerCallback
else:
    modelcard = getattr(import_lib("transformers"), "modelcard", object)
    INTEGRATION_TO_CALLBACK = getattr(
        import_lib("transformers.integrations.integration_utils"),
        "INTEGRATION_TO_CALLBACK",
        dict(),
    )
    rewrite_logs = getattr(
        import_lib("transformers.integrations.integration_utils"),
        "rewrite_logs",
        object,
    )
    get_reporting_integration_callbacks = getattr(
        import_lib("transformers.integrations.integration_utils"),
        "get_reporting_integration_callbacks",
        object,
    )
    Trainer = getattr(import_lib("transformers.trainer"), "Trainer", object)
    TrainerCallback = getattr(
        import_lib("transformers.trainer_callback"), "TrainerCallback", object
    )


logger = logging.getLogger(__name__.split(".")[0])
tag = "Transformers"


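# TrainerCallback that reports transformers training runs to mlop: run config,
# streamed metrics, and (when the LOG_MODEL env var is set) checkpoints and the
# final model as artifacts.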
class MLOPCallback(TrainerCallback):
    def __init__(self):
        # set state up front so later hooks never hit missing attributes
        self.op = None
        self._initialized = False
        self._log_model = os.getenv("LOG_MODEL", None)
        if not importlib.util.find_spec("mlop"):
            logger.error(f"{tag}: mlop is not installed")

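    # One-time run setup: merge TrainingArguments with the model (and PEFT)
    # config, pick a run name, then start a new mlop run or adopt an active one.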
    def setup(self, args, state, model, **kwargs):
        self._initialized = True

        if state.is_world_process_zero:
            conf = {**args.to_dict()}

            if hasattr(model, "config") and model.config is not None:
                model_config = (
                    model.config
                    if isinstance(model.config, dict)
                    else model.config.to_dict()
                )
                conf = {**model_config, **conf}
            if hasattr(model, "peft_config") and model.peft_config is not None:
                peft_config = model.peft_config
                conf = {**{"peft_config": peft_config}, **conf}

            init_args = {}
            if state.trial_name is not None:
                init_args["name"] = state.trial_name
                init_args["group"] = args.run_name
            elif args.run_name is not None:
                init_args["name"] = args.run_name
                if args.run_name == args.output_dir:
                    logger.warning(
                        f"{tag}: run_name is identical to TrainingArguments.output_dir; "
                        "set a distinct run_name if this was not intended."
                    )

            try:
                init_args["model"] = {"params": model.num_parameters()}
            except Exception as e:
                logger.error("%s: error getting model parameters: %s", tag, e)

            # reuse the most recent active run if one exists, otherwise start one
            if mlop.ops:
                self.op = mlop.ops[-1]
            else:
                init_args.setdefault("project", "transformers")
                self.op = mlop.init(**init_args, config=conf)

            watch_kwargs = {"freq": state.logging_steps} if state.logging_steps else {}
            self.op.watch(model, **watch_kwargs)

            if self._log_model:
                self.op.log({"info": mlop.Text(repr(model))})

                # TODO: add badge
                modelcard.AUTOGENERATED_TRAINER_COMMENT += "\n mlop.ai"

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if state.is_hyper_param_search:
            # each hyperparameter-search trial gets its own fresh run
            if self.op:
                self.op.finish()
            self._initialized = False
            args.run_name = None
        if not self._initialized:
            self.setup(args, state, model, **kwargs)

    def on_train_end(
        self, args, state, control, model=None, processing_class=None, **kwargs
    ):
        if not self.op:
            return

        if self._log_model and self._initialized and state.is_world_process_zero:
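            # borrow a throwaway Trainer purely for save_model(), so the final
            # weights land on disk in transformers' native checkpoint format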
            fake = Trainer(
                args=args,
                model=model,
                processing_class=processing_class,
                eval_dataset=["fake"],
            )
            fake.save_model(os.path.join(self.op.settings.get_dir(), "final"))
            self.op.log(
                {
                    "info": mlop.Text(model.__repr__()),
                    **{
                        f"model/{f.name}": mlop.Artifact(
                            data=os.path.abspath(f),
                            metadata={
                                f"eval/{args.metric_for_best_model}": state.best_metric,
                                "train/total_floss": state.total_floss,
                                "params": model.num_parameters(),
                            }
                            if state.best_metric
                            else {},
                        )
                        for f in Path(
                            os.path.join(self.op.settings.get_dir(), "final")
                        ).glob("**/*")
                        if f.is_file()
                    },
                }
            )

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        if not self.op:
            return
        if not self._initialized:
            self.setup(args, state, model, **kwargs)

        if state.is_world_process_zero:
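            # rewrite_logs() maps transformers' eval_*/test_* keys onto the
            # eval/ and test/ namespaces and everything else under train/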
            logs = rewrite_logs(logs)
            self.op.log({**logs, "train/global_step": state.global_step})

    def on_save(self, args, state, control, **kwargs):
        if self._log_model and self._initialized and state.is_world_process_zero:
            ckpt_dir = Path(args.output_dir) / f"checkpoint-{state.global_step}"
            self.op.log(
                {
                    f"model/{f.name}": mlop.Artifact(
                        data=os.path.abspath(f),
                        # read_text() would raise UnicodeDecodeError on binary
                        # checkpoint files (e.g. .safetensors, .bin), so only
                        # attach a text preview for text-like files
                        metadata=(
                            {"info": mlop.Text(f.read_text())}
                            if f.suffix in (".json", ".txt", ".md")
                            else {}
                        ),
                    )
                    for f in ckpt_dir.glob("**/*")
                    if f.is_file()
                }
            )

    def on_predict(self, args, state, control, metrics, **kwargs):
        self.on_log(args, state, control, logs=metrics, **kwargs)


# TODO: remove patch for transformers
# register the class rather than an instance so each Trainer constructs its own
# callback, matching how transformers maps integration names to callback classes
INTEGRATION_TO_CALLBACK["mlop"] = MLOPCallback
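
# Usage sketch (assumes transformers is installed; `model` and `train_dataset`
# below are illustrative placeholders):
#
#   from transformers import Trainer, TrainingArguments
#
#   args = TrainingArguments(output_dir="out", run_name="demo", report_to=["mlop"])
#   trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
#   trainer.train()  # "mlop" resolves to MLOPCallback via INTEGRATION_TO_CALLBACK
#
# or attach the callback explicitly with trainer.add_callback(MLOPCallback).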
